{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Count 1s in a sequence using LSTM and GRU cells\n",
    "\n",
    "In this example, we'll use PyTorch to implement LSTM and GRU cells (but we won't use the default implementations). We'll use them to count the number of 1s in a binary sequence of 0s and 1s."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's start with the imports:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import numpy as np\n",
    "import torch\n",
    "import typing"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we'll define some parameters of the networks and the training process:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "EPOCHS = 10  # training epochs\n",
    "TRAINING_SAMPLES = 10000  # training dataset size\n",
    "BATCH_SIZE = 16  # mini batch size\n",
    "TEST_SAMPLES = 1000  # test dataset size\n",
    "SEQUENCE_LENGTH = 20  # binary sequence length\n",
    "HIDDEN_UNITS = 20  # hidden units of the LSTM cell"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, let's implement a basic `LSTMCell` as a subclass of `torch.nn.Module`. The cell implementaion process only a single element of the sequence. Later, we'll include it in a larger recurrent module for processing whole sequences."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LSTMCell(torch.nn.Module):\n",
    "\n",
    "    def __init__(self, input_size: int, hidden_size: int):\n",
    "        \"\"\"\n",
    "        :param input_size: input vector size\n",
    "        :param hidden_size: cell state vector size\n",
    "        \"\"\"\n",
    "\n",
    "        super(LSTMCell, self).__init__()\n",
    "        self.input_size = input_size\n",
    "        self.hidden_size = hidden_size\n",
    "\n",
    "        # combine all gates in a single matrix multiplication\n",
    "        self.x_fc = torch.nn.Linear(input_size, 4 * hidden_size)\n",
    "        self.h_fc = torch.nn.Linear(hidden_size, 4 * hidden_size)\n",
    "\n",
    "        self.reset_parameters()\n",
    "\n",
    "    def reset_parameters(self):\n",
    "        \"\"\"Xavier initialization \"\"\"\n",
    "        size = math.sqrt(3.0 / self.hidden_size)\n",
    "        for weight in self.parameters():\n",
    "            weight.data.uniform_(-size, size)\n",
    "\n",
    "    def forward(self,\n",
    "                x_t: torch.Tensor,\n",
    "                hidden: typing.Tuple[torch.Tensor, torch.Tensor] = (None, None)) \\\n",
    "            -> typing.Tuple[torch.Tensor, torch.Tensor]:\n",
    "        h_t_1, c_t_1 = hidden  # t_1 is equivalent to t-1\n",
    "\n",
    "        # in case of more than 2-dimensional input\n",
    "        # flatten the tensor (similar to numpy.reshape)\n",
    "        x_t = x_t.view(-1, x_t.size(1))\n",
    "        h_t_1 = h_t_1.view(-1, h_t_1.size(1))\n",
    "        c_t_1 = c_t_1.view(-1, c_t_1.size(1))\n",
    "\n",
    "        # compute the activations of all gates simultaneously\n",
    "        gates = self.x_fc(x_t) + self.h_fc(h_t_1)\n",
    "\n",
    "        # split the input to the 4 separate gates\n",
    "        i_t, f_t, candidate_c_t, o_t = gates.chunk(4, 1)\n",
    "\n",
    "        # compute the activations for all gates\n",
    "        i_t, f_t, candidate_c_t, o_t = \\\n",
    "            i_t.sigmoid(), f_t.sigmoid(), candidate_c_t.tanh(), o_t.sigmoid()\n",
    "\n",
    "        # choose new state based on the input and forget gates\n",
    "        c_t = torch.mul(f_t, c_t_1) + torch.mul(i_t, candidate_c_t)\n",
    "\n",
    "        # compute the cell output\n",
    "        h_t = torch.mul(o_t, c_t.tanh())\n",
    "\n",
    "        return h_t, c_t"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, let's implement `LSTMModel`, which contains one `LSTMCell`, but handles a whole input sequence. At each step of the sequence, `LSTMModel` returns the model prediction based on the whole sequence up to the current step:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LSTMModel(torch.nn.Module):\n",
    "    \"\"\"LSTM model with a single output layer connected to the lstm cell output\"\"\"\n",
    "\n",
    "    def __init__(self, input_size, hidden_size, output_size):\n",
    "        super(LSTMModel, self).__init__()\n",
    "        self.hidden_size = hidden_size\n",
    "\n",
    "        # Our own LSTM implementation\n",
    "        self.lstm = LSTMCell(input_size, hidden_size)\n",
    "\n",
    "        # Fully-connected output layer\n",
    "        self.fc = torch.nn.Linear(hidden_size, output_size)\n",
    "\n",
    "    def forward(self, x):\n",
    "        # Start with empty network output and cell state to initialize the sequence\n",
    "        c_t = torch.zeros((x.size(0), self.hidden_size)).to(x.device)\n",
    "        h_t = torch.zeros((x.size(0), self.hidden_size)).to(x.device)\n",
    "\n",
    "        # Iterate over all sequence elements across all sequences of the mini-batch\n",
    "        for seq in range(x.size(1)):\n",
    "            h_t, c_t = self.lstm(x[:, seq, :], (h_t, c_t))\n",
    "\n",
    "        # Final output layer\n",
    "        return self.fc(h_t)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We'll follow the same blueprint to implement `GRUCell` and `GRUModel` respecively. Let's start with the `GRUCell` implementation:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "class GRUCell(torch.nn.Module):\n",
    "\n",
    "    def __init__(self, input_size: int, hidden_size: int):\n",
    "        \"\"\"\n",
    "        :param input_size: input vector size\n",
    "        :param hidden_size: cell state vector size\n",
    "        \"\"\"\n",
    "\n",
    "        super(GRUCell, self).__init__()\n",
    "        self.input_size = input_size\n",
    "        self.hidden_size = hidden_size\n",
    "\n",
    "        # x to reset gate r\n",
    "        self.x_r_fc = torch.nn.Linear(input_size, hidden_size)\n",
    "\n",
    "        # x to update gate z\n",
    "        self.x_z_fc = torch.nn.Linear(input_size, hidden_size)\n",
    "\n",
    "        # x to candidate state h'(t)\n",
    "        self.x_h_fc = torch.nn.Linear(input_size, hidden_size)\n",
    "\n",
    "        # network output/state h(t-1) to reset gate r\n",
    "        self.h_r_fc = torch.nn.Linear(hidden_size, hidden_size)\n",
    "\n",
    "        # network output/state h(t-1) to update gate z\n",
    "        self.h_z_fc = torch.nn.Linear(hidden_size, hidden_size)\n",
    "\n",
    "        # network state h(t-1) passed through the reset gate r towards candidate state h(t)\n",
    "        self.hr_h_fc = torch.nn.Linear(hidden_size, hidden_size)\n",
    "\n",
    "    def forward(self,\n",
    "                x_t: torch.Tensor,\n",
    "                h_t_1: torch.Tensor = None) \\\n",
    "            -> torch.Tensor:\n",
    "\n",
    "        # compute update gate vector\n",
    "        z_t = torch.sigmoid(self.x_z_fc(x_t) + self.h_z_fc(h_t_1))\n",
    "\n",
    "        # compute reset gate vector\n",
    "        r_t = torch.sigmoid(self.x_r_fc(x_t) + self.h_r_fc(h_t_1))\n",
    "\n",
    "        # compute candidate state\n",
    "        candidate_h_t = torch.tanh(self.x_h_fc(x_t) + self.hr_h_fc(torch.mul(r_t, h_t_1)))\n",
    "\n",
    "        # compute cell output\n",
    "        h_t = torch.mul(z_t, h_t_1) + torch.mul(1 - z_t, candidate_h_t)\n",
    "\n",
    "        return h_t"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We'll continue with the `GRUModel` class for processing whole sequences:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "class GRUModel(torch.nn.Module):\n",
    "    \"\"\"LSTM model with a single output layer connected to the lstm cell output\"\"\"\n",
    "\n",
    "    def __init__(self, input_size, hidden_size, output_size):\n",
    "        super(GRUModel, self).__init__()\n",
    "        self.hidden_size = hidden_size\n",
    "\n",
    "        # Our own GRU implementation\n",
    "        self.gru = GRUCell(input_size, hidden_size)\n",
    "\n",
    "        # Fully-connected output layer\n",
    "        self.fc = torch.nn.Linear(hidden_size, output_size)\n",
    "\n",
    "    def forward(self, x):\n",
    "        # Start with empty network output and cell state to initialize the sequence\n",
    "        h_t = torch.zeros((x.size(0), self.hidden_size)).to(x.device)\n",
    "\n",
    "        # Iterate over all sequence elements across all sequences of the mini-batch\n",
    "        for seq in range(x.size(1)):\n",
    "            h_t = self.gru(x[:, seq, :], h_t)\n",
    "\n",
    "        # Final output layer\n",
    "        return self.fc(h_t)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we'll implement the `generate_dataset` function, which generates a total of `samples` binary sequences, each with length of `sequence_length`. The function returns the sequence and it's numeric label, which indicates the number of 1s in that sequence:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_dataset(sequence_length: int, samples: int):\n",
    "    \"\"\"\n",
    "    Generate training/testing datasets\n",
    "    :param sequence_length: length of the binary sequence\n",
    "    :param samples: number of samples\n",
    "    \"\"\"\n",
    "\n",
    "    sequences = list()\n",
    "    labels = list()\n",
    "    for i in range(samples):\n",
    "        a = np.random.randint(sequence_length) / sequence_length\n",
    "        sequence = list(np.random.choice(2, sequence_length, p=[a, 1 - a]))\n",
    "        sequences.append(sequence)\n",
    "        labels.append(int(np.sum(sequence)))\n",
    "\n",
    "    sequences = np.array(sequences)\n",
    "    labels = np.array(labels, dtype=np.int8)\n",
    "\n",
    "    result = torch.utils.data.TensorDataset(\n",
    "        torch.from_numpy(sequences).float().unsqueeze(-1),\n",
    "        torch.from_numpy(labels).float())\n",
    "\n",
    "    return result"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We'll continue with the implementation of the training procedure for either `LSTMModel` or `GRUModel`. This procedure is generic and doesn't differ from similar procedures for feed-forward networks. The recurrence part is handled by PyTorch's _autodiff_ functionality within `LSTMMOdel` and `GRUModel`: "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_model(model, loss_function, optimizer, data_loader):\n",
    "    # set model to training mode\n",
    "    model.train()\n",
    "\n",
    "    current_loss = 0.0\n",
    "    current_acc = 0\n",
    "\n",
    "    # iterate over the training data\n",
    "    for i, (inputs, labels) in enumerate(data_loader):\n",
    "        # send the input/labels to the GPU\n",
    "        inputs = inputs.to(device)\n",
    "        labels = labels.to(device)\n",
    "\n",
    "        # zero the parameter gradients\n",
    "        model.zero_grad()\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        with torch.set_grad_enabled(True):\n",
    "            # forward\n",
    "            outputs = model(inputs).squeeze()\n",
    "            loss = loss_function(outputs, labels)\n",
    "\n",
    "            # backward\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "        # statistics\n",
    "        current_loss += loss.item() * inputs.size(0)\n",
    "        current_acc += torch.sum(outputs.round() == labels.data)\n",
    "\n",
    "    total_loss = current_loss / len(data_loader.dataset)\n",
    "    total_acc = current_acc.double() / len(data_loader.dataset)\n",
    "\n",
    "    print('Train Loss: {:.4f}; Accuracy: {:.4f}'.format(total_loss, total_acc))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then, we'll implement the testing procedure. As with the training, this is a generic procedure which is similar to the one for feed-forward networks, since the sequence processing is handled internally by `LSTMMOdel` and `GRUModel`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_model(model, loss_function, data_loader):\n",
    "    # set model in evaluation mode\n",
    "    model.eval()\n",
    "\n",
    "    current_loss = 0.0\n",
    "    current_acc = 0\n",
    "\n",
    "    # iterate over  the validation data\n",
    "    for i, (inputs, labels) in enumerate(data_loader):\n",
    "        # send the input/labels to the GPU\n",
    "        inputs = inputs.to(device)\n",
    "        labels = labels.to(device)\n",
    "\n",
    "        # forward\n",
    "        with torch.set_grad_enabled(False):\n",
    "            outputs = model(inputs).squeeze()\n",
    "            loss = loss_function(outputs, labels)\n",
    "\n",
    "        # statistics\n",
    "        current_loss += loss.item() * inputs.size(0)\n",
    "        current_acc += torch.sum(outputs.round() == labels.data)\n",
    "\n",
    "    total_loss = current_loss / len(data_loader.dataset)\n",
    "    total_acc = current_acc.double() / len(data_loader.dataset)\n",
    "\n",
    "    print('Test Loss: {:.4f}; Accuracy: {:.4f}'.format(total_loss, total_acc))\n",
    "\n",
    "    return total_loss, total_acc"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can now put it all together. We'll start by instantiting the `device`, the `train_loader`, and the `test_loader`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Select device\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "# Generate training and testing datasets\n",
    "train = generate_dataset(SEQUENCE_LENGTH, TRAINING_SAMPLES)\n",
    "train_loader = torch.utils.data.DataLoader(train, batch_size=BATCH_SIZE, shuffle=True)\n",
    "\n",
    "test = generate_dataset(SEQUENCE_LENGTH, TEST_SAMPLES)\n",
    "test_loader = torch.utils.data.DataLoader(test, batch_size=BATCH_SIZE, shuffle=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, let's instantiate an `LSTMModel`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instantiate LSTM model\n",
    "# input of size 1 for digit of the sequence\n",
    "# number of hidden units\n",
    "# regression model output size (number of ones)\n",
    "model = LSTMModel(input_size=1,\n",
    "                  hidden_size=HIDDEN_UNITS,\n",
    "                  output_size=1)\n",
    "\n",
    "# Transfer the model to the GPU\n",
    "model = model.to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we'll instantiate the training framework components..."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# loss function (we use MSELoss because of the regression)\n",
    "loss_function = torch.nn.MSELoss()\n",
    "\n",
    "# Adam optimizer\n",
    "optimizer = torch.optim.Adam(model.parameters())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "... and we'll run the training:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/10\n",
      "Train Loss: 59.7466; Accuracy: 0.0510\n",
      "Test Loss: 19.6232; Accuracy: 0.0700\n",
      "Epoch 2/10\n",
      "Train Loss: 6.7320; Accuracy: 0.4826\n",
      "Test Loss: 1.6951; Accuracy: 0.6920\n",
      "Epoch 3/10\n",
      "Train Loss: 0.8296; Accuracy: 0.8031\n",
      "Test Loss: 0.2831; Accuracy: 0.8570\n",
      "Epoch 4/10\n",
      "Train Loss: 0.1729; Accuracy: 0.8917\n",
      "Test Loss: 0.0736; Accuracy: 0.9290\n",
      "Epoch 5/10\n",
      "Train Loss: 0.0543; Accuracy: 0.9333\n",
      "Test Loss: 0.0257; Accuracy: 1.0000\n",
      "Epoch 6/10\n",
      "Train Loss: 0.0272; Accuracy: 0.9986\n",
      "Test Loss: 0.0124; Accuracy: 1.0000\n",
      "Epoch 7/10\n",
      "Train Loss: 0.0155; Accuracy: 0.9998\n",
      "Test Loss: 0.0125; Accuracy: 0.9980\n",
      "Epoch 8/10\n",
      "Train Loss: 0.0117; Accuracy: 0.9999\n",
      "Test Loss: 0.0102; Accuracy: 1.0000\n",
      "Epoch 9/10\n",
      "Train Loss: 0.0101; Accuracy: 0.9998\n",
      "Test Loss: 0.0052; Accuracy: 1.0000\n",
      "Epoch 10/10\n",
      "Train Loss: 0.0077; Accuracy: 1.0000\n",
      "Test Loss: 0.0041; Accuracy: 1.0000\n"
     ]
    }
   ],
   "source": [
    "for epoch in range(EPOCHS):\n",
    "    print('Epoch {}/{}'.format(epoch + 1, EPOCHS))\n",
    "\n",
    "    train_model(model, loss_function, optimizer, train_loader)\n",
    "    test_model(model, loss_function, test_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Unlike regular RNN, we can see that the LSTM network doesn't suffer from exploding gradients even with sequence length 20. <br />\n",
    "Let's do the same experiment with `GRUModel`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/10\n",
      "Train Loss: 55.0390; Accuracy: 0.0703\n",
      "Test Loss: 16.3110; Accuracy: 0.2040\n",
      "Epoch 2/10\n",
      "Train Loss: 7.0288; Accuracy: 0.5343\n",
      "Test Loss: 2.1039; Accuracy: 0.7260\n",
      "Epoch 3/10\n",
      "Train Loss: 1.0462; Accuracy: 0.8011\n",
      "Test Loss: 0.3727; Accuracy: 0.8660\n",
      "Epoch 4/10\n",
      "Train Loss: 0.2059; Accuracy: 0.8938\n",
      "Test Loss: 0.0754; Accuracy: 0.9280\n",
      "Epoch 5/10\n",
      "Train Loss: 0.0504; Accuracy: 0.9327\n",
      "Test Loss: 0.0237; Accuracy: 1.0000\n",
      "Epoch 6/10\n",
      "Train Loss: 0.0162; Accuracy: 0.9999\n",
      "Test Loss: 0.0098; Accuracy: 1.0000\n",
      "Epoch 7/10\n",
      "Train Loss: 0.0062; Accuracy: 1.0000\n",
      "Test Loss: 0.0045; Accuracy: 1.0000\n",
      "Epoch 8/10\n",
      "Train Loss: 0.0032; Accuracy: 1.0000\n",
      "Test Loss: 0.0039; Accuracy: 1.0000\n",
      "Epoch 9/10\n",
      "Train Loss: 0.0022; Accuracy: 1.0000\n",
      "Test Loss: 0.0018; Accuracy: 1.0000\n",
      "Epoch 10/10\n",
      "Train Loss: 0.0018; Accuracy: 1.0000\n",
      "Test Loss: 0.0017; Accuracy: 1.0000\n"
     ]
    }
   ],
   "source": [
    "model = GRUModel(input_size=1,\n",
    "                 hidden_size=HIDDEN_UNITS,\n",
    "                 output_size=1)\n",
    "\n",
    "# Transfer the model to the GPU\n",
    "model = model.to(device)\n",
    "\n",
    "# loss function (we use MSELoss because of the regression)\n",
    "loss_function = torch.nn.MSELoss()\n",
    "\n",
    "# Adam optimizer\n",
    "optimizer = torch.optim.Adam(model.parameters())\n",
    "\n",
    "# Train\n",
    "for epoch in range(EPOCHS):\n",
    "    print('Epoch {}/{}'.format(epoch + 1, EPOCHS))\n",
    "\n",
    "    train_model(model, loss_function, optimizer, train_loader)\n",
    "    test_model(model, loss_function, test_loader)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `GRUModel` achieved 100% accuracy even sooner than the `LSTMModel`. However, this is a toy dataset and we cannot use it as a comparison for the real-world performance of the 2 models."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
